"""
original from https://github.com/vchoutas/smplx
modified by Vassilis and Yao
"""

import pickle

import numpy as np
import torch
import torch.nn as nn

from .lbs import (
    JointsFromVerticesSelector,
    Struct,
    find_dynamic_lmk_idx_and_bcoords,
    lbs,
    to_np,
    to_tensor,
    vertices2landmarks,
)

# SMPLX
J14_NAMES = [
    "right_ankle", "right_knee", "right_hip", "left_hip", "left_knee", "left_ankle", "right_wrist",
    "right_elbow", "right_shoulder", "left_shoulder", "left_elbow", "left_wrist", "neck", "head"
]
SMPLX_names = [
    "pelvis", "left_hip", "right_hip", "spine1", "left_knee", "right_knee", "spine2", "left_ankle",
    "right_ankle", "spine3", "left_foot", "right_foot", "neck", "left_collar", "right_collar",
    "head", "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", "left_wrist",
    "right_wrist", "jaw", "left_eye_smplx", "right_eye_smplx", "left_index1", "left_index2",
    "left_index3", "left_middle1", "left_middle2", "left_middle3", "left_pinky1", "left_pinky2",
    "left_pinky3", "left_ring1", "left_ring2", "left_ring3", "left_thumb1", "left_thumb2",
    "left_thumb3", "right_index1", "right_index2", "right_index3", "right_middle1", "right_middle2",
    "right_middle3", "right_pinky1", "right_pinky2", "right_pinky3", "right_ring1", "right_ring2",
    "right_ring3", "right_thumb1", "right_thumb2", "right_thumb3", "right_eye_brow1",
    "right_eye_brow2", "right_eye_brow3", "right_eye_brow4", "right_eye_brow5", "left_eye_brow5",
    "left_eye_brow4", "left_eye_brow3", "left_eye_brow2", "left_eye_brow1", "nose1", "nose2",
    "nose3", "nose4", "right_nose_2", "right_nose_1", "nose_middle", "left_nose_1", "left_nose_2",
    "right_eye1", "right_eye2", "right_eye3", "right_eye4", "right_eye5", "right_eye6", "left_eye4",
    "left_eye3", "left_eye2", "left_eye1", "left_eye6", "left_eye5", "right_mouth_1",
    "right_mouth_2", "right_mouth_3", "mouth_top", "left_mouth_3", "left_mouth_2", "left_mouth_1",
    "left_mouth_5", "left_mouth_4", "mouth_bottom", "right_mouth_4", "right_mouth_5", "right_lip_1",
    "right_lip_2", "lip_top", "left_lip_2", "left_lip_1", "left_lip_3", "lip_bottom", "right_lip_3",
    "right_contour_1", "right_contour_2", "right_contour_3", "right_contour_4", "right_contour_5",
    "right_contour_6", "right_contour_7", "right_contour_8", "contour_middle", "left_contour_8",
    "left_contour_7", "left_contour_6", "left_contour_5", "left_contour_4", "left_contour_3",
    "left_contour_2", "left_contour_1", "head_top", "left_big_toe", "left_ear", "left_eye",
    "left_heel", "left_index", "left_middle", "left_pinky", "left_ring", "left_small_toe",
    "left_thumb", "nose", "right_big_toe", "right_ear", "right_eye", "right_heel", "right_index",
    "right_middle", "right_pinky", "right_ring", "right_small_toe", "right_thumb"
]
extra_names = [
    "head_top", "left_big_toe", "left_ear", "left_eye", "left_heel", "left_index", "left_middle",
    "left_pinky", "left_ring", "left_small_toe", "left_thumb", "nose", "right_big_toe", "right_ear",
    "right_eye", "right_heel", "right_index", "right_middle", "right_pinky", "right_ring",
    "right_small_toe", "right_thumb"
]
SMPLX_names += extra_names

part_indices = {}
part_indices["body"] = np.array([
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 123,
    124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143
])
part_indices["torso"] = np.array([
    0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 55, 56, 57, 58, 59, 76, 77, 78,
    79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101,
    102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
    121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
    140, 141, 142, 143, 144
])
part_indices["head"] = np.array([
    12, 15, 22, 23, 24, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73,
    74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
    98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
    117, 118, 119, 120, 121, 122, 123, 125, 126, 134, 136, 137
])
part_indices["face"] = np.array([
    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
    79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101,
    102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
    121, 122
])
part_indices["upper"] = np.array([
    12, 13, 14, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
    76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
    100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
    119, 120, 121, 122
])
part_indices["hand"] = np.array([
    20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46,
    47, 48, 49, 50, 51, 52, 53, 54, 128, 129, 130, 131, 133, 139, 140, 141, 142, 144
])
part_indices["left_hand"] = np.array([
    20, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 128, 129, 130, 131, 133
])
part_indices["right_hand"] = np.array([
    21, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 139, 140, 141, 142, 144
])
# kinematic tree
head_kin_chain = [15, 12, 9, 6, 3, 0]

# --smplx joints
# 00 - Global
# 01 - L_Thigh
# 02 - R_Thigh
# 03 - Spine
# 04 - L_Calf
# 05 - R_Calf
# 06 - Spine1
# 07 - L_Foot
# 08 - R_Foot
# 09 - Spine2
# 10 - L_Toes
# 11 - R_Toes
# 12 - Neck
# 13 - L_Shoulder
# 14 - R_Shoulder
# 15 - Head
# 16 - L_UpperArm
# 17 - R_UpperArm
# 18 - L_ForeArm
# 19 - R_ForeArm
# 20 - L_Hand
# 21 - R_Hand
# 22 - Jaw
# 23 - L_Eye
# 24 - R_Eye


class SMPLX(nn.Module):
    """
    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh and 3D joints
    """
    def __init__(self, config):
        super(SMPLX, self).__init__()
        # print("creating the SMPLX Decoder")
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)

        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
        )
        # The shape components and expression
        # expression space is the same as FLAME
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)
        # The pose components
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
        )
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer("lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
        # for face keypoints
        self.register_buffer(
            "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
        )
        # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
        self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))

        # -- initialize parameters
        # shape and expression
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
        )
        # pose: represented as rotation matrx [number of joints, 3, 3]
        self.register_buffer(
            "global_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "head_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "neck_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "jaw_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "eye_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(2, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "body_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(21, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "left_hand_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 1),
                requires_grad=False,
            ),
        )
        self.register_buffer(
            "right_hand_pose",
            nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 1),
                requires_grad=False,
            ),
        )

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in J14_NAMES:
                    source.append(idx)
                    target.append(J14_NAMES.index(name))
            source = np.asarray(source)
            target = np.asarray(target)
            self.register_buffer("source_idxs", torch.from_numpy(source))
            self.register_buffer("target_idxs", torch.from_numpy(target))
            self.register_buffer(
                "extra_joint_regressor",
                torch.from_numpy(j14_regressor).to(torch.float32)
            )
            self.part_indices = part_indices

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """
        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3]
        """
        if shape_params is None:
            batch_size = global_pose.shape[0]
            shape_params = self.shape_params.expand(batch_size, -1)
        else:
            batch_size = shape_params.shape[0]
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        # smplx
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )
        # face dynamic landmarks
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            # Add any extra joints that might be needed
            extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
            final_joint_set.append(extra_joints)
        # Create the final joint set
        joints = torch.cat(final_joint_set, dim=1)
        # if self.use_joint_regressor:
        #     reg_joints = torch.einsum("ji,bik->bjk",
        #                               self.extra_joint_regressor, vertices)
        #     joints[:, self.source_idxs] = reg_joints[:, self.target_idxs]

        return vertices, landmarks, joints

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """change absolute pose to relative pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)
        """
        if abs_joint == "head":
            # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
            kin_chain = [15, 12, 9, 6, 3, 0]
        elif abs_joint == "neck":
            # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
            kin_chain = [12, 9, 6, 3, 0]
        elif abs_joint == "right_wrist":
            # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder
            # -> right elbow -> right wrist
            kin_chain = [21, 19, 17, 14, 9, 6, 3, 0]
        elif abs_joint == "left_wrist":
            # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder
            # -> Left elbow -> Left wrist
            kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
        else:
            raise NotImplementedError(f"pose_abs2rel does not support: {abs_joint}")

        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        rel_rot_mat = (
            torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
        )
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

        # This contains the absolute pose of the parent
        abs_parent_pose = rel_rot_mat.detach()
        # Let's assume that in the input this specific joint is predicted as an absolute value
        abs_joint_pose = body_pose[:, kin_chain[0] - 1]
        # abs_head = parents(abs_neck) * rel_head ==> rel_head = abs_neck.T * abs_head
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        # Replace the new relative pose
        body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """change relative pose to absolute pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)

        if abs_joint == "head":
            # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
            kin_chain = [15, 12, 9, 6, 3, 0]
        elif abs_joint == "neck":
            # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
            kin_chain = [12, 9, 6, 3, 0]
        elif abs_joint == "right_wrist":
            # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder
            # -> right elbow -> right wrist
            kin_chain = [21, 19, 17, 14, 9, 6, 3, 0]
        elif abs_joint == "left_wrist":
            # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder
            # -> Left elbow -> Left wrist
            kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
        else:
            raise NotImplementedError(f"pose_rel2abs does not support: {abs_joint}")
        rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
        return abs_pose
